home *** CD-ROM | disk | FTP | other *** search
/ MacHack 2001 / MacHack 2001.toast / pc / The Hacks / TiVo™ for QuicktimeTV™ / RTSP-playback.py next >
Encoding:
Python Source  |  2001-06-23  |  15.1 KB  |  570 lines

  1. #!/usr/local/bin/python
  2.  
  3. """
  4.  
  5. RTSP Proxy v1.2
  6. ---------------
  7. Jonathan Hogg <jonathan@onegoodidea.com>
  8.  
  9. Copyright (c) 1999 One Good Idea Limited <http://www.onegoodidea.com/>
  10.  
  11. Permission to use, copy, modify, and distribute this software and its
  12. documentation for any purpose, without fee, and without a written agreement
  13. is hereby granted, provided that the above copyright notice and this
  14. paragraph and the following two paragraphs appear in all copies.
  15.  
  16. IN NO EVENT SHALL ONE GOOD IDEA LIMITED BE LIABLE TO ANY PARTY FOR DIRECT, 
  17. INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST 
  18. PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, 
  19. EVEN IF ONE GOOD IDEA LIMITED HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH 
  20. DAMAGE.
  21.  
  22. ONE GOOD IDEA LIMITED SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, 
  23. BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS 
  24. FOR A PARTICULAR PURPOSE.  THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS 
  25. IS" BASIS, AND ONE GOOD IDEA LIMITED HAS NO OBLIGATIONS TO PROVIDE 
  26. MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
  27.  
  28.  
  29. Usage:
  30.  
  31.     % RTSP_Proxy
  32.  
  33.  
  34. The proxy listens on port 7070 so that it doesn't need to be run as root
  35. to operate (although this can be easily changed down the bottom of the
  36. script). It is a very simple program and can get confused, but in it's
  37. present state is about as functional as Apple's rtsp_proxy but a lot less
  38. buggy.
  39.  
  40. """
  41.  
  42.  
  43. import cPickle
  44. import sys
  45. import string
  46. import StringIO
  47. import re
  48. import time
  49. from threading import *
  50.  
  51. from socket import *
  52. if not globals().has_key('IPPROTO_TCP'):
  53.     IPPROTO_TCP = 6
  54.  
  55. from select import *
  56.  
  57. import urlparse
  58. try:
  59.     if "rtsp" not in urlparse.uses_netloc:
  60.         urlparse.uses_netloc.append("rtsp")
  61. except:
  62.     pass
  63.  
  64.  
  65.  
  66. #------------------------------------------------------------------------
  67.  
  68. class Logger:
  69.  
  70.     def __init__( self, file = sys.stderr ):
  71.         self._lastmsg = ''
  72.         self._first = 1
  73.         self._repeats = 0
  74.         self._file = file
  75.         self._file.write( "[log started]" )
  76.         self._lock = Lock()
  77.     
  78.     def log( self, msg ):
  79.         self._lock.acquire()
  80.         if msg == self._lastmsg:
  81.             if self._repeats == 0:
  82.                 self._file.write( ' (.' )
  83.             self._file.write( '.' )
  84.             self._repeats = self._repeats + 1
  85.         else:
  86.             if self._repeats > 0:
  87.                 self._file.write( ')' )
  88.             self._file.write( '\n' )
  89.             self._first = 0
  90.             self._file.write( msg )
  91.             self._repeats = 0
  92.         self._file.flush()
  93.         if self._repeats == 75 - len(msg):
  94.             self._lastmsg = ''
  95.         else:
  96.             self._lastmsg = msg
  97.         self._lock.release()
  98.  
  99.  
  100. logger = Logger()
  101. debug = logger.log
  102.  
  103.  
  104. def makeportrange( ports ):
  105.  
  106.     if len(ports) == 1:
  107.         return "%d" % ports[0]
  108.     else:
  109.         return "%d-%d" % (ports[0], ports[-1])
  110.  
  111.  
  112.  
  113. #------------------------------------------------------------------------
  114.  
  115. def messageFromConnection(conn):
  116.     buffer = conn.recv( 10240 )
  117.     m = Message(buffer)
  118.     print("\n-----------------------------\n%s----------------------\n" % m.getmessage())
  119.     return m
  120.  
  121. class Message:
  122.  
  123.     def __init__( self, messageString ):
  124.  
  125.         self._input = StringIO.StringIO(messageString)
  126.         self._buffer = ""
  127.         self.readcommand()
  128.         self.readheaders()
  129.         self.readcontent()
  130.  
  131.     
  132.     
  133.     def readdata( self ):
  134.  
  135.         self._buffer = self._buffer + self._input.read(1024)
  136.  
  137.  
  138.     def getdata( self, length):
  139.     
  140.         while 1:
  141.             if len(self._buffer) >= length:
  142.                 data = self._buffer[0:length]
  143.                 self._buffer = self._buffer[length:]
  144.                 return data
  145.             else:
  146.                 self.readdata()
  147.  
  148.  
  149.     def readline( self ):
  150.     
  151.         while 1:
  152.             if self._buffer == "":
  153.                 self.readdata()
  154.             
  155.             pos = string.find( self._buffer, "\r\n" )
  156.         
  157.             if pos <> -1:
  158.                 line = self._buffer[:pos]
  159.                 self._buffer = self._buffer[pos+2:]
  160.                 return line
  161.         
  162.             self.readdata()
  163.  
  164.     def readcommand( self ):
  165.     
  166.         line = self.readline()
  167.         bits = string.split( line )
  168.         self._command = bits[0]
  169.         self._arguments = bits[1:]
  170.  
  171.  
  172.     def readheaders( self ):
  173.  
  174.         self._headerdict = {}
  175.         self._headerlist = []
  176.         
  177.         while 1:
  178.             line = self.readline()
  179.             if line == "":
  180.                 break
  181.             if line[0] in string.whitespace:
  182.                 header[1] = header[1] + string.lstrip(line)
  183.             else:
  184.                 (field,value) = string.split( line, ":", 1 )
  185.                 header = [field, string.strip(value)]
  186.                 self._headerlist.append( header )
  187.                 self._headerdict[string.lower(field)] = header
  188.  
  189.  
  190.     def readcontent( self ):
  191.         
  192.         length = self.getheader('content-length')
  193.         if length:
  194.             self._content = self.getdata( int(length) )
  195.         else:
  196.             self._content = ""
  197.  
  198.  
  199.     def getmessage( self ):
  200.     
  201.         msg = self._command + " " + string.join( self._arguments ) + "\r\n"
  202.         
  203.         for header in self._headerlist:
  204.             msg = msg + "%s: %s\r\n" % (header[0], header[1])
  205.  
  206.         msg = msg + "\r\n" + self._content
  207.         
  208.         return msg
  209.         
  210.  
  211.     def getheader( self, field ):
  212.         
  213.         name = string.lower( field )
  214.         if self._headerdict.has_key( name ):
  215.             return self._headerdict[name][1]
  216.         else:
  217.             return None
  218.  
  219.  
  220.     def setheader( self, field, value ):
  221.         
  222.         self._headerdict[string.lower(field)][1] = value
  223.  
  224.  
  225.     def getcommand( self ):
  226.  
  227.         return self._command
  228.     
  229.     
  230.     def setcommand( self, command ):
  231.  
  232.         self._command = command
  233.     
  234.     
  235.     def getargs( self ):
  236.     
  237.         return self._arguments
  238.  
  239.  
  240.     def setargs( self, args ):
  241.         self._arguments = args
  242.     
  243.     
  244.  
  245. #------------------------------------------------------------------------
  246.  
  247. class Session( Thread ):
  248.  
  249.     START_PORT = 40000
  250.     _currentport = START_PORT
  251.  
  252.     def __init__( self, conn, addr, _from, _to, archive ):
  253.     
  254.         Thread.__init__( self )
  255.         self._clientconn = conn
  256.         self._clientaddr = addr
  257.         self._from = _from
  258.         self._to = _to
  259.     self._portsMapping = {}
  260.     self._archive = archive
  261.     self.setDaemon( 1 )
  262.  
  263.     def _allocateports( self, howmany ):
  264.     
  265.         start = Session._currentport
  266.         sofar = 0
  267.         socks = []
  268.         
  269.         while sofar < howmany:
  270.         
  271.             sock = socket( AF_INET, SOCK_DGRAM )
  272.             port = Session._currentport
  273.             Session._currentport = Session._currentport + 1
  274.             
  275.             try:
  276.                 sock.bind( ('',port) )
  277.             except:
  278.                 sofar = 0
  279.                 start = self._currentport
  280.                 socks = []
  281.                 
  282.             socks.append( (port,sock) )
  283.             sofar = sofar + 1
  284.             end = port
  285.         
  286.         debug( "  allocated a port range at %d-%d" % (start,end) )
  287.  
  288.         return socks
  289.  
  290.     
  291.     def sendclientmsg( self, msg ):
  292.     debug("---------\nACTUALLY SENDING:\n" + msg.getmessage() + "---------------\n")
  293.         self._clientconn.send( msg.getmessage() )
  294.     
  295.     def getservermsg( self ):
  296.         return messageFromConnection( self._serverconn )
  297.  
  298.     def sendservermsg( self, msg ):
  299.         self._serverconn.send( msg.getmessage() )
  300.     
  301.     def dispatch( self, msg ):
  302.         command = msg.getcommand()
  303.         
  304.         debug( "GOT command: " + msg.getmessage() )
  305.         
  306.         if command == "DESCRIBE":
  307.             self.do_passthrough( msg )
  308.             
  309.         elif command == "SETUP":
  310.             self.do_setup( msg )
  311.             
  312.         elif command == "OPTIONS":
  313.             self.do_passthrough( msg )
  314.  
  315.         elif command == "PLAY":
  316.             self.do_play( msg )
  317.             # we need to start playing back data
  318.  
  319.         else:
  320.             self.sendservermsg( msg )
  321.             response = self.getservermsg()
  322.             self.sendclientmsg( response )
  323.     
  324.     
  325.     def do_options( self, msg ):
  326.     
  327.         if self._client_type[:4] == 'QTS/' and self._server_type == 'QTSS/v66':
  328.             debug( '  translating OPTIONS into a GET_PARAMETER ping for broken QuickTime' )
  329.             msg.setcommand( 'GET_PARAMETER' )
  330.             self.sendservermsg( msg )
  331.  
  332.             msg.setcommand( 'RTSP/1.0' )
  333.             msg.setargs( ['200', 'OK'] )
  334.             self.sendclientmsg( msg )
  335.         
  336.         else:
  337.             self.sendservermsg( msg )
  338.             response = self.getservermsg()
  339.             self.sendclientmsg( response )
  340.  
  341.     def do_passthrough( self, msg ):
  342.     command = msg.getcommand()
  343.     outMsg = self._to[command][0]
  344.     self._to[command] = self._to[command][1:]
  345.         self.sendclientmsg(outMsg)
  346.  
  347.     def do_play(self, msg):
  348.     # TODO -- start forcing out data in a timed way on the SETUP channels
  349.     self._playback = Playback(self._archive, self._clientaddr, self._portsMapping)
  350.     self._playback.start()
  351.     command = msg.getcommand()
  352.     outMsg = self._to[command][0]
  353.     self._to[command] = self._to[command][1:]
  354.     # change session
  355.     outMsg.setheader('session', msg.getheader('session'))
  356.         self.sendclientmsg(outMsg)
  357.  
  358.     def do_setup( self, msg ):
  359.     # TODO -- parse SETUP message
  360.     # grab corresponding archived SETUP message
  361.     # create output port pair
  362.     # modify server_port portion
  363.     # add track, index mapping
  364.  
  365.         client_port = ''
  366.         
  367.         debug( "  client requests of proxy:\n    %s" % msg.getheader('transport') )
  368.         
  369.         for bit in string.split( msg.getheader('transport'), ";" ):
  370.             bit = string.strip( bit )
  371.             
  372.             if string.find( bit, '=' ) > 0:
  373.                 name, value = string.split( bit, '=', 1 )
  374.             
  375.                 if name == 'client_port':
  376.                     client_port = value
  377.  
  378.         if string.find( client_port, "-" ):
  379.             startport,endport = string.split( client_port, "-" )
  380.             clientports = range( int(startport), int(endport) + 1 )
  381.         else:
  382.             clientports = [ int(client_port) ]
  383.  
  384.     # create a port range
  385.     portRange = self._allocateports(len(clientports))
  386.  
  387.     # recover old SETUP response (should be indexed by track ID)
  388.     command = msg.getcommand()
  389.     print "1->", self._to[command]
  390.     print "2->", self._to[command][0]
  391.     response = self._to[command][0]
  392.     self._to[command] = self._to[command][1:]
  393.     debug("OLD: ->" + str(response.getmessage()))
  394.  
  395.     # dig out track ID
  396.     URI = msg.getargs()[0]
  397.     trackID = 0 # TODO -- fix!!            
  398.     m = re.search("trackID=(\d+)", URI)
  399.     if m != None: trackID = int(m.group(1))
  400.     print "track ID = ", trackID
  401.  
  402.     # change session
  403.     response.setheader('session', msg.getheader('session'))
  404.  
  405.     # salt away port mappings
  406.     for index in range(len(clientports)):
  407.         self._portsMapping[(trackID,index)] = (portRange[index][1],clientports[index])
  408.  
  409.     # TODO send hacked response
  410.  
  411.         debug( "  server offers to proxy:\n    " + response.getheader('transport') )
  412.         
  413.     aPortRange = map(lambda x: x[0], portRange)
  414.         response.setheader( 'transport',
  415.                             'RTP/AVP;unicast;client_port=%s;server_port=%s' % (client_port,
  416.                                 makeportrange(aPortRange)))
  417.  
  418.         debug( "  playback offers to client:\n    " + response.getheader('transport') )
  419.         
  420.         self.sendclientmsg( response )
  421.  
  422.  
  423.     def run( self ):
  424.     
  425. #        try:
  426.             while 1:
  427.                 msg = messageFromConnection( self._clientconn )
  428.                 self.dispatch( msg )
  429.         
  430. #        except:
  431.             debug( "taking down session" )
  432.             self._clientconn.close()
  433.             if self._serverconn:
  434.                 self._serverconn.close()
  435.  
  436.     def stop(self):
  437.     self._clientconn.close()
  438.  
  439.  
  440.  
  441. #------------------------------------------------------------------------
  442.  
  443. class Listener:
  444.  
  445.  
  446.     def __init__( self, port ):
  447.     
  448.         self._sock = socket( AF_INET, SOCK_STREAM )
  449.         self._sock.bind( ('',port) )
  450.         self._sock.setsockopt( IPPROTO_TCP, SO_REUSEADDR, 1 )
  451.         self._sock.listen( 5 )
  452.  
  453.     def waitforclient( self , _from, to, archive):
  454.     
  455.         conn, addr = self._sock.accept()
  456.         debug( "accepted connection from %s:%d" % addr )
  457.         return Session( conn, addr[0], _from, to, archive )
  458.  
  459.     def stop( self ):
  460.         self._sock.close()
  461.  
  462.  
  463.  
  464. #------------------------------------------------------------------------
  465.  
  466. class Playback( Thread ):
  467.  
  468.     def __init__( self, archive, clientaddr, portsByTrackID ):
  469.         # portsByTrackID: hash with keys of the form (trackID, index)
  470.     # and values of form (socket, clientPort)
  471.         Thread.__init__( self )
  472.     archive.seek(0)
  473.         self._archive = archive
  474.     self._clientaddr = clientaddr
  475.         self._portsByTrackID = portsByTrackID
  476.     self.setDaemon( 1 )
  477.  
  478.     
  479.     def doplayback(self):
  480.     time.sleep(1)
  481.     self.startTime = time.time()
  482.     self.archiveStartTime = None
  483.     while 1:
  484.         m = cPickle.load(self._archive)
  485.         # look for a 'data'
  486.         if m[1] != 'data': continue
  487.         track = m[2]
  488.         index = m[3]
  489.         playTime = m[0]
  490.         data = m[4]
  491.         if self.archiveStartTime == None: self.archiveStartTime = playTime
  492.         archiveDifference = playTime - self.archiveStartTime
  493.         currentDifference = time.time() - self.startTime
  494.         (socket, clientPort) = self._portsByTrackID[(track,index)]
  495.         print "data len = %d; clientaddr = %s; client port = %d" % (len(data), str(self._clientaddr), clientPort)
  496.         timeToWait = archiveDifference - currentDifference
  497.         if timeToWait > 0.1:
  498.             time.sleep(timeToWait)
  499.         sys.stdout.write(".")
  500.         sys.stdout.flush()
  501.         bytesWrit = socket.sendto(data, (self._clientaddr, clientPort))
  502.  
  503.     def run( self ):
  504.         debug( "  starting RTP playback: " + str(self._portsByTrackID))
  505.     self.doplayback()    
  506.  
  507.  
  508.  
  509. #------------------------------------------------------------------------
  510.  
  511. def getRTSPMessages(f):
  512.         fromClientMessages = {}
  513.         toClientMessages = {}
  514.     fromPending = []
  515.         try:
  516.                 while 1:
  517.                         ob = cPickle.load(f)
  518.                         if ob[1] == 'messageFromClient':
  519.                 m = Message(ob[2])
  520.                 command = m.getcommand()
  521.                 if not fromClientMessages.has_key(command): fromClientMessages[command] = []
  522.                 fromClientMessages[command].append(m)
  523.                 fromPending.append(command)
  524.                         if ob[1] == 'messageToClient':
  525.                 m = Message(ob[2])
  526.                 command = fromPending[0]
  527.                 fromPending = fromPending[1:]
  528.                 if not toClientMessages.has_key(command): toClientMessages[command] = []
  529.                 toClientMessages[command].append(m)
  530.         except Exception, e:
  531.                 print 'exception ',e
  532.                 # show traceback information
  533.                 tb = sys.exc_info()[2]
  534.                 print("Exception: " + str(sys.exc_info()[0]) + " line " + str(tb))
  535.     print fromClientMessages,toClientMessages
  536.     return fromClientMessages,toClientMessages
  537.  
  538.  
  539.  
  540. def main( argv ):
  541.     while 1:
  542.       try:
  543.         listener = Listener( 7272 )
  544.         break
  545.       except:
  546.     print "sleeping 5"
  547.         time.sleep(5)
  548.  
  549.  
  550.     archive = open(argv[1], "r")
  551.  
  552.     _from,to = getRTSPMessages(archive)
  553.     
  554.     debug( "waiting for a client" )
  555.  
  556.     try:
  557.         while 1:
  558.             client = listener.waitforclient(_from,to, archive)
  559.             listener.stop()
  560.             client.start()
  561.             while 1: time.sleep(100)
  562.     finally:
  563.         listener.stop()
  564.     print "Stopping client"
  565.     client.stop()
  566.  
  567.  
  568. if __name__ == "__main__":
  569.     main( sys.argv )
  570.